{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# FrozenLake-v1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import logging\n",
    "import itertools\n",
    "\n",
    "import numpy as np\n",
    "np.random.seed(0)\n",
    "import gym\n",
    "\n",
    "# send DEBUG-and-above records to stdout as 'HH:MM:SS [LEVEL] message'\n",
    "logging.basicConfig(level=logging.DEBUG,\n",
    "        format='%(asctime)s [%(levelname)s] %(message)s',\n",
    "        stream=sys.stdout, datefmt='%H:%M:%S')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Use Environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "00:00:00 [INFO] observation space = Discrete(16)\n",
      "00:00:00 [INFO] action space = Discrete(4)\n",
      "00:00:00 [INFO] number of states = 16\n",
      "00:00:00 [INFO] number of actions = 4\n",
      "00:00:00 [INFO] P[14] = {0: [(0.3333333333333333, 10, 0.0, False), (0.3333333333333333, 13, 0.0, False), (0.3333333333333333, 14, 0.0, False)], 1: [(0.3333333333333333, 13, 0.0, False), (0.3333333333333333, 14, 0.0, False), (0.3333333333333333, 15, 1.0, True)], 2: [(0.3333333333333333, 14, 0.0, False), (0.3333333333333333, 15, 1.0, True), (0.3333333333333333, 10, 0.0, False)], 3: [(0.3333333333333333, 15, 1.0, True), (0.3333333333333333, 10, 0.0, False), (0.3333333333333333, 13, 0.0, False)]}\n",
      "00:00:00 [INFO] P[14][2] = [(0.3333333333333333, 14, 0.0, False), (0.3333333333333333, 15, 1.0, True), (0.3333333333333333, 10, 0.0, False)]\n",
      "00:00:00 [INFO] reward threshold = 0.7\n",
      "00:00:00 [INFO] max episode steps = 100\n"
     ]
    }
   ],
   "source": [
    "# build the environment and seed it\n",
    "env = gym.make('FrozenLake-v1')\n",
    "env.seed(0)\n",
    "logging.info('observation space = %s', env.observation_space)\n",
    "logging.info('action space = %s', env.action_space)\n",
    "logging.info('number of states = %s', env.observation_space.n)\n",
    "logging.info('number of actions = %s', env.action_space.n)\n",
    "# P[state][action] is a list of (probability, next_state, reward, done) tuples\n",
    "logging.info('P[14] = %s', env.unwrapped.P[14])\n",
    "logging.info('P[14][2] = %s', env.unwrapped.P[14][2])\n",
    "logging.info('reward threshold = %s', env.spec.reward_threshold)\n",
    "logging.info('max episode steps = %s', env.spec.max_episode_steps)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Play Using Random Policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def play_policy(env, policy, render=False):\n",
    "    \"\"\"Run one episode under a stochastic policy and return its total reward.\n",
    "\n",
    "    policy[observation] is used as the action distribution at each step;\n",
    "    with render=True the environment is drawn before every action.\n",
    "    \"\"\"\n",
    "    total_reward = 0.\n",
    "    state = env.reset()\n",
    "    finished = False\n",
    "    while not finished:\n",
    "        if render:\n",
    "            env.render()\n",
    "        action_probs = policy[state]\n",
    "        action = np.random.choice(env.action_space.n, p=action_probs)\n",
    "        state, reward, finished, _ = env.step(action)\n",
    "        total_reward += reward\n",
    "    return total_reward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "00:00:00 [INFO] ==== Random policy ====\n",
      "00:00:00 [INFO] average episode reward = 0.04 ± 0.20\n"
     ]
    }
   ],
   "source": [
    "logging.info('==== Random policy ====')\n",
    "# uniform distribution over actions in every state\n",
    "random_policy = np.ones((env.observation_space.n,\n",
    "        env.action_space.n)) / env.action_space.n\n",
    "\n",
    "# play 100 episodes and report mean and standard deviation of the episode rewards\n",
    "episode_rewards = [play_policy(env, random_policy)  for _ in range(100)]\n",
    "logging.info('average episode reward = %.2f ± %.2f',\n",
    "        np.mean(episode_rewards), np.std(episode_rewards))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate Policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def v2q(env, v, state=None, gamma=1.):\n",
    "    \"\"\"Convert state values v into action values.\n",
    "\n",
    "    With a state given, returns a vector of q(state, action) over all actions;\n",
    "    with state=None, returns the full (states x actions) table.\n",
    "    \"\"\"\n",
    "    n_actions = env.action_space.n\n",
    "    if state is None: # build the table one row (state) at a time\n",
    "        n_states = env.observation_space.n\n",
    "        table = np.zeros((n_states, n_actions))\n",
    "        for s in range(n_states):\n",
    "            table[s] = v2q(env, v, s, gamma)\n",
    "        return table\n",
    "\n",
    "    # expected one-step backup over the listed transitions of each action\n",
    "    q = np.zeros(n_actions)\n",
    "    for action in range(n_actions):\n",
    "        outcomes = env.unwrapped.P[state][action]\n",
    "        for prob, next_state, reward, done in outcomes:\n",
    "            bootstrap = gamma * v[next_state] * (1. - done) # zero when done is True\n",
    "            q[action] += prob * (reward + bootstrap)\n",
    "    return q\n",
    "\n",
    "def evaluate_policy(env, policy, gamma=1., tolerant=1e-6):\n",
    "    \"\"\"Iteratively compute the state values of `policy`.\n",
    "\n",
    "    Sweeps all states, updating values in place, until the largest change\n",
    "    within a sweep is below `tolerant`; returns the value vector.\n",
    "    \"\"\"\n",
    "    n_states = env.observation_space.n\n",
    "    v = np.zeros(n_states)\n",
    "    converged = False\n",
    "    while not converged:\n",
    "        max_change = 0\n",
    "        for state in range(n_states):\n",
    "            action_values = v2q(env, v, state, gamma)\n",
    "            new_value = sum(policy[state] * action_values) # expectation under the policy\n",
    "            max_change = max(max_change, abs(v[state]-new_value))\n",
    "            v[state] = new_value\n",
    "        converged = max_change < tolerant\n",
    "    return v"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Evaluate Random Policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "00:00:00 [INFO] state value:\n",
      "[[0.0139372  0.01162942 0.02095187 0.01047569]\n",
      " [0.01624741 0.         0.04075119 0.        ]\n",
      " [0.03480561 0.08816967 0.14205297 0.        ]\n",
      " [0.         0.17582021 0.43929104 0.        ]]\n",
      "00:00:00 [INFO] action value:\n",
      "[[0.01470727 0.01393801 0.01393801 0.01316794]\n",
      " [0.00852221 0.01162969 0.01086043 0.01550616]\n",
      " [0.02444416 0.0209521  0.02405958 0.01435233]\n",
      " [0.01047585 0.01047585 0.00698379 0.01396775]\n",
      " [0.02166341 0.01701767 0.0162476  0.01006154]\n",
      " [0.         0.         0.         0.        ]\n",
      " [0.05433495 0.04735099 0.05433495 0.00698396]\n",
      " [0.         0.         0.         0.        ]\n",
      " [0.01701767 0.04099176 0.03480569 0.04640756]\n",
      " [0.0702086  0.11755959 0.10595772 0.05895286]\n",
      " [0.18940397 0.17582024 0.16001408 0.04297362]\n",
      " [0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.        ]\n",
      " [0.08799662 0.20503708 0.23442697 0.17582024]\n",
      " [0.25238807 0.53837042 0.52711467 0.43929106]\n",
      " [0.         0.         0.         0.        ]]\n"
     ]
    }
   ],
   "source": [
    "# state values of random_policy, displayed as a 4x4 grid\n",
    "v_random = evaluate_policy(env, random_policy)\n",
    "logging.info('state value:\\n%s', v_random.reshape(4, 4))\n",
    "\n",
    "# action values derived from those state values (one row per state)\n",
    "q_random = v2q(env, v_random)\n",
    "logging.info('action value:\\n%s', q_random)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Improve Policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def improve_policy(env, v, policy, gamma=1.):\n",
    "    \"\"\"Make `policy` greedy with respect to v, modifying it in place.\n",
    "\n",
    "    Returns True when every state already put probability 1 on its greedy\n",
    "    action (so nothing was changed), otherwise False.\n",
    "    \"\"\"\n",
    "    changed = False\n",
    "    for state in range(env.observation_space.n):\n",
    "        best_action = np.argmax(v2q(env, v, state, gamma))\n",
    "        if policy[state][best_action] == 1.:\n",
    "            continue # this state is already greedy\n",
    "        changed = True\n",
    "        policy[state] = 0.\n",
    "        policy[state][best_action] = 1.\n",
    "    return not changed"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Improve from random policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "00:00:00 [INFO] Updating completes. Updated policy is:\n",
      "[[1. 0. 0. 0.]\n",
      " [0. 0. 0. 1.]\n",
      " [1. 0. 0. 0.]\n",
      " [0. 0. 0. 1.]\n",
      " [1. 0. 0. 0.]\n",
      " [1. 0. 0. 0.]\n",
      " [1. 0. 0. 0.]\n",
      " [1. 0. 0. 0.]\n",
      " [0. 0. 0. 1.]\n",
      " [0. 1. 0. 0.]\n",
      " [1. 0. 0. 0.]\n",
      " [1. 0. 0. 0.]\n",
      " [1. 0. 0. 0.]\n",
      " [0. 0. 1. 0.]\n",
      " [0. 1. 0. 0.]\n",
      " [1. 0. 0. 0.]]\n"
     ]
    }
   ],
   "source": [
    "# work on a copy: improve_policy rewrites the policy array in place\n",
    "policy = random_policy.copy()\n",
    "optimal = improve_policy(env, v_random, policy)\n",
    "if optimal:\n",
    "    logging.info('No update. Optimal policy is:\\n%s', policy)\n",
    "else:\n",
    "    logging.info('Updating completes. Updated policy is:\\n%s', policy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Iterate Policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "00:00:00 [INFO] optimal state value =\n",
      "[[0.82351246 0.82350689 0.82350303 0.82350106]\n",
      " [0.82351416 0.         0.5294002  0.        ]\n",
      " [0.82351683 0.82352026 0.76469786 0.        ]\n",
      " [0.         0.88234658 0.94117323 0.        ]]\n",
      "00:00:00 [INFO] optimal policy =\n",
      "[[0 3 3 3]\n",
      " [0 0 0 0]\n",
      " [3 1 0 0]\n",
      " [0 2 1 0]]\n"
     ]
    }
   ],
   "source": [
    "def iterate_policy(env, gamma=1., tolerant=1e-6):\n",
    "    \"\"\"Policy iteration: alternate policy evaluation and greedy improvement.\n",
    "\n",
    "    Starts from the uniform policy and stops once improve_policy reports\n",
    "    that no state changed its action. Returns (policy, v), where v is the\n",
    "    evaluation of the returned policy.\n",
    "    \"\"\"\n",
    "    n_actions = env.action_space.n\n",
    "    policy = np.ones((env.observation_space.n, n_actions)) / n_actions # uniform start\n",
    "    while True:\n",
    "        v = evaluate_policy(env, policy, gamma, tolerant)\n",
    "        # gamma must be forwarded: improving with the default gamma=1. would be\n",
    "        # greedy for a different objective than the one that was just evaluated\n",
    "        if improve_policy(env, v, policy, gamma):\n",
    "            break\n",
    "    return policy, v\n",
    "\n",
    "# run policy iteration; np.argmax turns each policy row into its chosen action index\n",
    "policy_pi, v_pi = iterate_policy(env)\n",
    "logging.info('optimal state value =\\n%s', v_pi.reshape(4, 4))\n",
    "logging.info('optimal policy =\\n%s', np.argmax(policy_pi, axis=1).reshape(4, 4))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test Policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "00:00:00 [INFO] average episode reward = 0.77 ± 0.42\n"
     ]
    }
   ],
   "source": [
    "# play 100 episodes with the policy-iteration result and summarize the rewards\n",
    "episode_rewards = [play_policy(env, policy_pi)  for _ in range(100)]\n",
    "logging.info('average episode reward = %.2f ± %.2f',\n",
    "        np.mean(episode_rewards), np.std(episode_rewards))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Value Iteration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "00:00:00 [INFO] optimal state value =\n",
      "[[0.82351232 0.82350671 0.82350281 0.82350083]\n",
      " [0.82351404 0.         0.52940011 0.        ]\n",
      " [0.82351673 0.82352018 0.76469779 0.        ]\n",
      " [0.         0.88234653 0.94117321 0.        ]]\n",
      "00:00:00 [INFO] optimal policy = \n",
      "[[0 3 3 3]\n",
      " [0 0 0 0]\n",
      " [3 1 0 0]\n",
      " [0 2 1 0]]\n"
     ]
    }
   ],
   "source": [
    "def iterate_value(env, gamma=1, tolerant=1e-6):\n",
    "    \"\"\"Value iteration followed by greedy policy extraction.\n",
    "\n",
    "    Repeatedly replaces each state's value with its best action value until\n",
    "    the largest change in a sweep is below `tolerant`, then builds a policy\n",
    "    that puts probability 1 on the argmax action of every state.\n",
    "    Returns (policy, v).\n",
    "    \"\"\"\n",
    "    n_states = env.observation_space.n\n",
    "    v = np.zeros(n_states)\n",
    "    converged = False\n",
    "    while not converged:\n",
    "        max_change = 0\n",
    "        for state in range(n_states):\n",
    "            best_value = max(v2q(env, v, state, gamma))\n",
    "            max_change = max(max_change, abs(v[state]-best_value))\n",
    "            v[state] = best_value\n",
    "        converged = max_change < tolerant\n",
    "\n",
    "    # read the greedy policy off the converged values\n",
    "    policy = np.zeros((n_states, env.action_space.n))\n",
    "    for state in range(n_states):\n",
    "        greedy_action = np.argmax(v2q(env, v, state, gamma))\n",
    "        policy[state][greedy_action] = 1.\n",
    "    return policy, v\n",
    "\n",
    "# run value iteration; np.argmax turns each policy row into its chosen action index\n",
    "policy_vi, v_vi = iterate_value(env)\n",
    "logging.info('optimal state value =\\n%s', v_vi.reshape(4, 4))\n",
    "logging.info('optimal policy = \\n%s', np.argmax(policy_vi, axis=1).reshape(4, 4))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test Policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "00:00:00 [INFO] average episode reward = 0.70 ± 0.46\n"
     ]
    }
   ],
   "source": [
    "# play 100 episodes with the value-iteration result and summarize the rewards\n",
    "episode_rewards = [play_policy(env, policy_vi) for _ in range(100)]\n",
    "logging.info('average episode reward = %.2f ± %.2f',\n",
    "        np.mean(episode_rewards), np.std(episode_rewards))"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
